import networkx as nx
import csv
import math
import numpy as np
import torch

from flow import *

def read_net(road_net_filename, lcc=False):
    '''
        Reads a directed graph from a CSV edge list, optionally restricted
        to its largest weakly connected component.

        Each non-empty row adds one edge from its first column to its second
        column. Node labels are the raw CSV strings; extra columns are ignored.

        :param road_net_filename: path to the CSV edge list.
        :param lcc: if True, return only the largest weakly connected
            component, as a subgraph view of the full graph.
        :return: networkx.DiGraph (a subgraph view of one when lcc is True).
            An empty file yields an empty DiGraph.
    '''
    G = nx.DiGraph()

    # newline='' is how the csv module expects its input files to be opened.
    with open(road_net_filename, 'r', newline='') as file_in:
        for row in csv.reader(file_in):
            if not row:
                # Blank lines (e.g. a trailing newline) produce empty rows;
                # skip them rather than failing on row[0].
                continue
            G.add_edge(row[0], row[1])

    if lcc:
        # max() returns the first largest component, the same one that
        # sorted(..., reverse=True)[0] would pick, without sorting them all.
        # default=None covers an empty graph, which has no components.
        largest = max(nx.weakly_connected_components(G), key=len, default=None)
        if largest is not None:
            return G.subgraph(largest)
    return G

def normalize_features(features):
    """Standardizes feature vectors column-wise with sklearn's StandardScaler.

    :param features: dict key -> 1-D array-like; all vectors must have the
        same length. Rows are stacked in the dict's iteration order.
    :return: dict with the same keys (same order) mapping to the scaled rows
        as float64 numpy arrays. An empty dict yields an empty dict.
    """
    if not features:
        # Nothing to scale; also avoids indexing into an empty key list.
        return {}

    from sklearn.preprocessing import StandardScaler

    keys = list(features)
    # One float64 row per key, in dict order.
    feat_matrix = np.array([features[key] for key in keys], dtype=float)
    feat_matrix = StandardScaler().fit_transform(feat_matrix)

    return dict(zip(keys, feat_matrix))


def make_non_neg_norm(G, flows, features):
    '''
        Converts a flow estimation instance to a non-negative, normalized one.

        Every edge (u, v) with a negative flow is reversed to (v, u) and its
        flow negated, so all known flows become non-negative. Known flows are
        then divided by the largest absolute flow, which places them in [0, 1].
        Edges without a known flow are copied as-is and get no entry in the
        returned flow dict. Each edge's feature vector follows the edge
        (re-keyed when the edge is reversed).

        :param G: graph whose edges define the instance; only G.edges() is read.
        :param flows: dict edge -> flow value (possibly negative) for a subset
            of G's edges.
        :param features: dict edge -> feature vector; must contain every edge
            of G (a missing edge raises KeyError).
        :return: (new_G, new_flows, new_feat), where new_G is a networkx.DiGraph.
    '''
    new_flows = {}
    new_G = nx.DiGraph()
    new_feat = {}

    # Scale by the largest flow *magnitude*. A plain maximum is negative when
    # every flow is negative (flipping every output negative), and it is too
    # small when a reversed negative flow dominates (outputs above 1).
    # Empty or all-zero flows fall back to 1 to avoid np.max([]) / 0-division.
    max_flow = np.max(np.abs(list(flows.values()))) if flows else 0.0
    if max_flow == 0:
        max_flow = 1.0

    for e in G.edges():
        u, v = e
        if e in flows and flows[e] < 0:
            # NOTE: if G also contains (v, u), the two edges share the key
            # (v, u) here and the later one in iteration order wins.
            new_e = (v, u)
            new_G.add_edge(v, u)
            new_flows[new_e] = -flows[e] / max_flow
            new_feat[new_e] = features[e]
        else:
            new_G.add_edge(u, v)
            if e in flows:
                new_flows[e] = flows[e] / max_flow
            new_feat[e] = features[e]

    return new_G, new_flows, new_feat

# Golden-ratio constants used by gss(): invphi = 1/phi, invphi2 = 1/phi^2.
invphi = (math.sqrt(5) - 1) / 2
invphi2 = (3 - math.sqrt(5)) / 2

def gss(f, args, a, b, tol=1e-5):
    '''Golden-section search for the minimum of a unimodal function.

    f is called as f(x, args) and is assumed to have a single local minimum
    on the interval between a and b (which may be given in either order).
    Each step shrinks the bracket by a factor of 1/phi and reuses one of the
    two interior evaluations, until the bracket width is at most tol.

    modified from: https://en.wikipedia.org/wiki/Golden-section_search

    Usage: gss(f_gss, [G_reg_2, ups, super_regions, updates_proj, .5, False, recall], 0., 1.)

    :return: tuple (lo, hi) that brackets the minimizer, with hi - lo <= tol.
    '''
    lo, hi = min(a, b), max(a, b)
    width = hi - lo
    if width <= tol:
        return (lo, hi)

    # Number of shrink steps so that invphi**steps * width <= tol.
    steps = int(math.ceil(math.log(tol / width) / math.log(invphi)))

    left_probe = lo + invphi2 * width
    right_probe = lo + invphi * width
    f_left = f(left_probe, args)
    f_right = f(right_probe, args)

    for _ in range(steps - 1):
        width = invphi * width
        if f_left < f_right:
            # Minimum lies in [lo, right_probe]: the old left probe becomes
            # the new right probe, and only a new left probe is evaluated.
            hi, right_probe, f_right = right_probe, left_probe, f_left
            left_probe = lo + invphi2 * width
            f_left = f(left_probe, args)
        else:
            # Minimum lies in [left_probe, hi]: mirror image of the above.
            lo, left_probe, f_left = left_probe, right_probe, f_right
            right_probe = lo + invphi * width
            f_right = f(right_probe, args)

    return (lo, right_probe) if f_left < f_right else (left_probe, hi)


def sparse_tensor_from_coo_matrix(matrix):
    """
    Builds a torch sparse COO tensor from a SciPy sparse matrix.

    The input is first converted with .tocoo(), so CSR/CSC inputs work too.
    Indices become int64 and values float32. Duplicate entries are passed
    through as-is (no coalescing).
    """
    as_coo = matrix.tocoo()
    # Stack row and column indices into the (2, nnz) layout torch expects.
    row_col = np.stack([as_coo.row, as_coo.col])
    return torch.sparse_coo_tensor(
        torch.tensor(row_col, dtype=torch.long),
        torch.tensor(as_coo.data, dtype=torch.float),
        torch.Size(as_coo.shape),
    )



def identify_sources_sinks(G, flows, tol=1e-6):
    """
    Identifies source and sink nodes in the graph based on flow data.

    A node's net flow is (sum of flows on edges leaving it) minus (sum of
    flows on edges entering it). Nodes with net flow above tol are sources;
    nodes with net flow below -tol are sinks; all others are neither.

    Parameters:
        G (networkx.Graph): The flow graph; only G.nodes() is read.
        flows (dict): Dictionary mapping (u, v) edges to their flow values.
            Both endpoints must be nodes of G (otherwise KeyError).
        tol (float): Net-flow magnitude a node must exceed to be classified,
            absorbing floating-point noise. Defaults to 1e-6.

    Returns:
        sources (set): Set of source nodes.
        sinks (set): Set of sink nodes.
    """
    incoming_flows = dict.fromkeys(G.nodes(), 0.0)
    outgoing_flows = dict.fromkeys(G.nodes(), 0.0)

    for (u, v), flow in flows.items():
        outgoing_flows[u] += flow
        incoming_flows[v] += flow

    sources = set()
    sinks = set()
    for node in G.nodes():
        net_flow = outgoing_flows[node] - incoming_flows[node]
        if net_flow > tol:
            sources.add(node)
        elif net_flow < -tol:
            sinks.add(node)

    return sources, sinks

def add_source_sink_features(G, features, sources, sinks):
    """
    Appends two binary indicators to each edge's feature vector: whether the
    edge's from-node is a source and whether its to-node is a sink.

    Parameters:
        G (networkx.Graph): The flow graph; only G.edges() is read.
        features (dict): Existing edge features (edge -> array-like). Numpy
            arrays, lists and scalars are accepted; edges missing from the
            dict start from an empty feature vector.
        sources (set): Set of source nodes.
        sinks (set): Set of sink nodes.

    Returns:
        updated_features (dict): Edge -> 1-D numpy array made of the original
        features followed by [is_source, is_sink], each 1.0 or 0.0.
    """
    updated_features = {}

    for e in G.edges():
        u, v = e
        is_source = 1.0 if u in sources else 0.0
        is_sink = 1.0 if v in sinks else 0.0

        # asarray/atleast_1d let list or scalar features concatenate like
        # arrays; a missing edge becomes an empty float vector.
        orig_feat = np.atleast_1d(np.asarray(features.get(e, ())))

        updated_features[e] = np.concatenate([orig_feat, [is_source, is_sink]])

    return updated_features


def compute_laplacian_eigenvectors(G, k=16):
    """
    Computes Laplacian-eigenvector positional encodings for the nodes of G.

    Uses the eigenvectors of the normalized Laplacian belonging to the 2nd
    through (k+1)-th smallest eigenvalues; index 0 is skipped as the trivial
    eigenvector of a connected graph. For a disconnected graph, some of the
    kept columns may also correspond to zero eigenvalues. Each node's row is
    then L2-normalized (all-zero rows stay zero).

    Parameters:
        G (networkx.Graph): The input graph.
        k (int): Length of each encoding. A graph with n nodes has only n - 1
            eigenvectors after the skipped one; if k > n - 1, the missing
            columns are zero-filled so every encoding still has length k.

    Returns:
        pos_enc (dict): Mapping from node to its length-k numpy array.
    """
    from scipy.linalg import eigh

    nodes = list(G.nodes())
    n_nodes = len(nodes)
    enc = np.zeros((n_nodes, k))

    n_eig = min(k, n_nodes - 1)
    if n_eig > 0:
        # Dense normalized Laplacian with rows/columns in `nodes` order.
        # NOTE(review): eigh assumes a symmetric matrix (it reads one
        # triangle); confirm what networkx returns for a directed G, or pass
        # G.to_undirected().
        lap = nx.normalized_laplacian_matrix(G, nodelist=nodes).toarray()
        # subset_by_index is inclusive: indices 1..n_eig give n_eig
        # eigenvectors, ordered by ascending eigenvalue.
        _, eigvecs = eigh(lap, subset_by_index=[1, n_eig])
        enc[:, :n_eig] = eigvecs

    # Row-wise L2 normalization; guard zero rows against division by zero.
    norms = np.linalg.norm(enc, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    enc /= norms

    return {node: enc[i] for i, node in enumerate(nodes)}

def add_positional_encodings(G, features, pos_enc):
    """
    Appends node positional encodings to edge feature vectors.

    For every edge (u, v) of G, the new feature vector is the existing edge
    feature vector (empty if the edge has none), followed by pos_enc[u],
    followed by pos_enc[v].

    Parameters:
        G (networkx.Graph): The input graph; only G.edges() is read.
        features (dict): Existing edge features (edge -> array-like). Numpy
            arrays, lists and scalars are accepted; missing edges start empty.
        pos_enc (dict): Positional encoding per node; must contain both
            endpoints of every edge (otherwise KeyError).

    Returns:
        updated_features (dict): Edge -> 1-D numpy array with the source and
        target encodings appended.
    """
    updated_features = {}

    for e in G.edges():
        u, v = e
        # asarray/atleast_1d let list or scalar features concatenate like
        # arrays; a missing edge becomes an empty float vector.
        orig_feat = np.atleast_1d(np.asarray(features.get(e, ())))
        updated_features[e] = np.concatenate([orig_feat, pos_enc[u], pos_enc[v]])

    return updated_features


def add_edge_centrality_features(G, features, measure='betweenness'):
    """
    Computes node-level centrality for each node in G and appends the centrality
    of the edge's source and target to the edge feature vector.

    :param G: networkx.DiGraph (or Graph)
    :param features: dict {edge: array-like}, existing edge features. Edges
        missing from the dict start from an empty feature vector.
    :param measure: str naming a networkx centrality ('betweenness',
        'closeness' or 'degree'), or a callable that takes G and returns a
        {node: centrality} dict.
    :return: dict {edge: np.array([...])}, updated edge features with 2 extra dims
    :raises ValueError: if measure is neither callable nor a supported name.
    """
    # 1. Compute node-level centrality
    if callable(measure):
        node_centralities = measure(G)
    elif measure == 'betweenness':
        node_centralities = nx.betweenness_centrality(G)
    elif measure == 'closeness':
        node_centralities = nx.closeness_centrality(G)
    elif measure == 'degree':
        # For directed graphs, you might use in_degree_centrality or out_degree_centrality
        node_centralities = nx.degree_centrality(G)
    else:
        raise ValueError(f"Unsupported centrality measure: {measure}")

    updated_features = {}

    # 2. Append the centralities of (u,v) to each edge's feature
    for e in G.edges():
        u, v = e
        # asarray/atleast_1d let list or scalar features concatenate like
        # arrays; a missing edge becomes an empty float vector.
        old_feat = np.atleast_1d(np.asarray(features.get(e, ())))
        c_u = node_centralities[u]
        c_v = node_centralities[v]

        # Combine them (2 new dimensions)
        updated_features[e] = np.concatenate([old_feat, [c_u, c_v]])

    return updated_features
